# Copyright (c) OpenMMLab. All rights reserved.
import os
import json
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
import torch.distributed as dist
from mmengine.dist import infer_launcher, init_dist
from PIL import Image
from xtuner._lite.datasets.utils import apply_exif_orientation

# Disable PIL's decompression-bomb pixel limit so that very large images in the
# datasets can still be opened and measured.
Image.MAX_IMAGE_PIXELS = None

# Optional petrel client import. Only the `has_tcs_loader` flag is recorded
# here; image sizes in this script are still read with PIL from local paths.
try:
    from petrel_client.client import Client
    from petrel_client.common.config import Config

    has_tcs_loader = True
except ImportError as E:
    print('petrel_client is not installed. Using PIL to load images.')
    has_tcs_loader = False


def calc_image_size(data_item, root):
    """Parse one JSONL annotation line and attach the size of its image(s).

    Args:
        data_item (str): One line of a JSONL annotation file.
        root (str): Directory onto which the item's ``image`` path(s) are
            joined with ``os.path.join``.

    Returns:
        dict: The parsed item. If it has an ``image`` field (a single path or
        a list of paths), an ``image_wh`` list of ``[width, height]`` pairs is
        added, one per image, measured after ``apply_exif_orientation``.
        If any image cannot be loaded (including ``s3://`` paths, which are
        not supported), ``image_wh`` is ``[[0, 0]]``. If the line cannot be
        decoded as JSON, a placeholder ``{'image_wh': [[0, 0]]}`` is returned.
        Items without an ``image`` field are returned unchanged.
    """
    try:
        data_item = json.loads(data_item)
    except (TypeError, ValueError) as e:
        # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors;
        # TypeError covers non-text input.
        print(f" decoder error {e}: {root}", flush=True)
        # There is no parsed dict to annotate (assigning a key on the raw
        # string would itself raise), so emit a zero-size placeholder record.
        return {'image_wh': [[0, 0]]}

    image_path = data_item.get('image')
    if image_path is None:
        return data_item

    paths = image_path if isinstance(image_path, list) else [image_path]
    # Tracks the path being processed so the error message names it.
    current_path = image_path
    try:
        image_wh = []
        for path in paths:
            current_path = os.path.join(root, path)
            if "s3://" in current_path:
                raise NotImplementedError
            # Close the file handle promptly: PIL opens lazily and otherwise
            # keeps the file open until the image object is collected.
            with Image.open(current_path) as image:
                w, h = apply_exif_orientation(image).size
            image_wh.append([w, h])
    except Exception as e:
        # Best-effort: any unreadable image marks the whole item as 0x0.
        print(f"error {e}: {current_path}", flush=True)
        image_wh = [[0, 0]]
    data_item['image_wh'] = image_wh
    return data_item


def save_jsonl(data_list, output_file):
    """Write each record of ``data_list`` to ``output_file`` as one JSON line.

    The file is overwritten and encoded as UTF-8; non-ASCII characters are
    written verbatim rather than ``\\u``-escaped.
    """
    json_lines = (
        json.dumps(record, ensure_ascii=False) + '\n' for record in data_list
    )
    with open(output_file, 'w', encoding='utf-8') as out:
        out.writelines(json_lines)


if __name__ == '__main__':
    dist_launcher = infer_launcher()
    init_dist(dist_launcher)

    meta_path = 'demo_data/internvl2_mpo.json'
    ds_collections = json.loads(open(meta_path).read())

    _dataset_list = []
    _dataset_lengths = []
    for _, ds_name in enumerate(ds_collections.keys()):
        _data = ds_collections[ds_name]
        if _data['root'].strip() == '':
            continue
        _dataset_list.append([ds_name, ds_collections[ds_name]])

    print('total files', len(_dataset_list))

    if dist.is_available():
        world_size = dist.get_world_size()
        rank = dist.get_rank()
    else:
        world_size = 1
        rank = 0

    _dataset_list = _dataset_list[rank::world_size]

    print(f'[{rank}] Assigned Files: {_dataset_list}')

    for i, _dataset in enumerate(tqdm(_dataset_list)):
        ds_name, _data = _dataset
        root = _data['root']

        with open(_data['annotation'], 'r') as f:
            lines = f.readlines()
        roots = [root] * len(lines)

        with ThreadPoolExecutor(max_workers=32) as executor:
            data = list(
                tqdm(
                    executor.map(calc_image_size, lines, roots),
                    desc='calc_image_size',
                    total=len(lines)))
        save_path = _data['annotation'][:-len('.jsonl')] + f'_with_wh.jsonl'
        try:
            save_jsonl(data, save_path)
        except Exception as e:
            # 可能某些文件没有写权限
            print('Error:', e)
            save_dir = './'
            file_name = save_path.split('/')[-1]
            save_path = os.path.join(save_dir, file_name)
            save_jsonl(data, save_path)

        print(f'[{rank}] {save_path} saved.')
